# sequential_grouping.py
from typing import List, Dict, Any, Optional
import numpy as np


class SequentialGroupingStrategy:
    """
    Top-k Sequential Grouping:
      1) Sort hyperparameters by importance (descending). Tie-break by name (ascending) for determinism.
      2) Slice the ordered list into consecutive groups of size `group_size` (last group may be smaller).
      3) Allocate trials across groups proportionally to group importance (sum of member importances).

    Edge handling in allocation:
      - If there are no groups or total_trials <= 0: the allocation is an empty list.
      - If total_trials < #groups: give 1 to the top `total_trials` groups, 0 to the rest.
      - If total importance <= 0: even split; remainder to top group for determinism.
      - Otherwise: floor-based proportional allocation with a minimum of one trial
        per group. Leftover trials go to the top group. If the one-trial minimum
        over-allocates the budget, the excess is reclaimed from the lowest-ranked
        groups first (never below 1), so the allocation always sums to
        `total_trials`.
    """

    def __init__(
        self,
        importance: Optional[Dict[str, float]],
        search_space: Dict[str, Any],
        group_size: int = 2,
    ):
        """
        Parameters
        ----------
        importance : Dict[str, float] or None
            Mapping from parameter name to importance weight. It is copied, so
            later mutation of the caller's dict does not affect this object.
            `None` is treated as an empty mapping.
        search_space : Dict[str, Any]
            Full search space (kept for interface compatibility; not read here).
        group_size : int
            Size of each group (k). Must be >= 1.

        Raises
        ------
        ValueError
            If `group_size` is smaller than 1.
        """
        if group_size < 1:
            raise ValueError("group_size must be >= 1")
        self.importance = dict(importance) if importance is not None else {}
        self.search_space = search_space
        self.group_size = int(group_size)

        # Groups and their weights are computed once; weights are index-aligned
        # with `self.grouped`.
        self.grouped: List[List[str]] = self._group_by_top_k()
        self._group_weights: List[float] = [
            float(sum(self.importance[p] for p in group)) for group in self.grouped
        ]

    # ---------------------------
    # Public API
    # ---------------------------

    def get_ordered_groups(self) -> List[List[str]]:
        """
        Return the groups in slicing order (most important parameters first).

        The returned list is the internal one, not a copy.
        """
        return self.grouped

    def allocate_budget(self, total_trials: int) -> List[int]:
        """
        Allocate `total_trials` across groups proportionally by group importance.

        Parameters
        ----------
        total_trials : int
            Total number of trials to distribute.

        Returns
        -------
        List[int]
            One non-negative budget per group, index-aligned with
            `get_ordered_groups()`, summing to `total_trials`. An empty list is
            returned when there are no groups or `total_trials <= 0`.

        Edge cases handled:
          - If total_trials < #groups: give 1 to top groups, 0 to the rest.
          - If total importance <= 0: even split; remainder to top group.
          - If the one-trial minimum over-allocates: excess is reclaimed from
            the lowest-ranked groups first, never below 1.
        """
        num_groups = len(self._group_weights)
        if num_groups == 0 or total_trials <= 0:
            return []

        # Fewer trials than groups -> importance-first, one per top group
        if total_trials < num_groups:
            return [1] * total_trials + [0] * (num_groups - total_trials)

        total_importance = float(sum(self._group_weights))

        # All-zero (or numerically non-positive) importance -> even split
        if total_importance <= 0.0:
            alloc = [total_trials // num_groups] * num_groups
            alloc[0] += total_trials - sum(alloc)  # remainder to the top group
            return alloc

        # Proportional via floor, but every group keeps at least one trial.
        alloc = [
            max(1, int(np.floor(total_trials * (w / total_importance))))
            for w in self._group_weights
        ]

        surplus = total_trials - sum(alloc)
        if surplus >= 0:
            alloc[0] += surplus  # simple + deterministic
            return alloc

        # The one-trial minimum pushed the sum above the budget. Take the excess
        # back starting from the least important group so higher-ranked groups
        # are trimmed last. Because total_trials >= num_groups, the trials held
        # above the minimum always cover the excess.
        excess = -surplus
        for idx in range(num_groups - 1, -1, -1):
            if excess == 0:
                break
            take = min(alloc[idx] - 1, excess)
            alloc[idx] -= take
            excess -= take

        return alloc

    def export_group_schedule(self, total_trials: int) -> List[Dict[str, Any]]:
        """
        Returns a list of stages with:
            - 'group': list[str] of parameter names in this stage
            - 'budget': int, number of trials to run for this group

        The list is empty when `allocate_budget` returns no budgets
        (no groups, or `total_trials <= 0`).
        """
        groups = self.get_ordered_groups()
        budgets = self.allocate_budget(total_trials)
        return [{"group": g, "budget": b} for g, b in zip(groups, budgets)]

    # ---------------------------
    # Internal helpers
    # ---------------------------

    def _group_by_top_k(self) -> List[List[str]]:
        """
        Sort parameters by importance desc, tie-break by name asc, then slice by group_size.
        """
        if not self.importance:
            return []

        # Deterministic order: by importance desc, then name asc
        ordered = sorted(self.importance.items(), key=lambda kv: (-kv[1], kv[0]))
        names = [name for name, _ in ordered]

        # Slice into consecutive groups of size `group_size`
        k = self.group_size
        return [names[i:i + k] for i in range(0, len(names), k)]
